package tcode

import (
	"context"
	"errors"
	_ "github.com/go-sql-driver/mysql"
	"strconv"
	"strings"
)

// mysqlDbParse holds the MySQL-specific schema introspection (current schema,
// table list, SHOW FULL COLUMNS parsing) used to build struct metadata for
// code generation. Its methods never read the receiver's value; the int8 is
// only a placeholder type to hang the methods on.
type mysqlDbParse int8

// currentDbName returns the schema name reported by MySQL's SELECT DATABASE()
// on the connection that NewSqlInfo resolves from ctx. A query error is logged
// through FuncLog and returned together with whatever value was read.
func (*mysqlDbParse) currentDbName(ctx context.Context) (dbName string, err error) {
	if err = NewSqlInfo(ctx, "SELECT DATABASE()").ToVar(&dbName); err != nil {
		FuncLog(err)
	}
	return dbName, err
}

// selectTableName lists every table_name found in information_schema.tables
// for the current schema (see currentDbName). Errors are logged through
// FuncLog and returned.
func (mysql *mysqlDbParse) selectTableName(ctx context.Context) ([]string, error) {
	dbName, err := mysql.currentDbName(ctx)
	if err != nil {
		FuncLog(err)
		return nil, err
	}
	// The explicit alias pins the result-set label: MySQL 8.0 reports
	// information_schema column names in upper case (TABLE_NAME) no matter how
	// the query spells them, which would make the lookup below miss.
	table, err := NewSqlInfo(ctx, "SELECT table_name AS table_name FROM information_schema.tables WHERE table_schema = ? ", dbName).ToRawTable(nil)
	if err != nil {
		FuncLog(err)
		return nil, err
	}
	return table.GetColumnByName("table_name"), nil
}

// parseFields reads the column definitions of fullyTableName with
// SHOW FULL COLUMNS and converts each result row into a *fieldInfo.
//
// fullyTableName is concatenated into the statement verbatim (identifiers
// cannot be bound as query parameters), so callers must pass a trusted or
// already-quoted name such as `db`.`table`.
//
// For each column the Go field name comes from Hump, the JSON name from
// PrefixLower, and the Go type from columnTypeParse. The package-level
// listenField hook, when non-nil, is then invoked with the field. The first
// error aborts the whole parse and is logged through FuncLog.
func (mysql *mysqlDbParse) parseFields(ctx context.Context, fullyTableName string) ([]*fieldInfo, error) {
	table, err := NewSqlInfo(ctx, "SHOW FULL COLUMNS FROM "+fullyTableName).ToRawTable(nil)
	if err != nil {
		FuncLog(err)
		return nil, err
	}
	rowCount := table.RowCount()
	fields := make([]*fieldInfo, 0, rowCount)
	for i := 0; i < rowCount; i++ {
		row := table.GetRow(i)
		fi := &fieldInfo{}
		fi.ColumnName = row["Field"]
		fi.ColumnType = row["Type"]
		fi.Collection = row["Collation"]
		fi.DefaultValue = row["Default"]
		fi.Comment = row["Comment"]
		fi.Name = Hump(fi.ColumnName)
		fi.JsonName = PrefixLower(fi.Name)
		fi.Type = fi.ColumnType
		fi.NotNull = strings.EqualFold(row["Null"], "NO")
		fi.PrimaryKey = strings.EqualFold(row["Key"], "PRI")
		// Extra may hold several space-separated attributes
		// (e.g. "auto_increment INVISIBLE"), so match the token rather than
		// the whole value.
		for _, attr := range strings.Fields(row["Extra"]) {
			if strings.EqualFold(attr, "AUTO_INCREMENT") {
				fi.AutoIncrement = true
				break
			}
		}
		if err = mysql.columnTypeParse(fi); err != nil {
			FuncLog(err)
			return nil, err
		}
		if listenField != nil {
			listenField(fi)
		}
		fields = append(fields, fi)
	}
	return fields, nil
}

// ToStructInfo builds the struct metadata for a single table.
//
// fullyTableName must have the form "database.table". It is split at the
// first '.', and both parts must be non-empty. When tableComment is empty,
// the comment is read from information_schema.tables.
//
// The result carries:
//   - the Go struct name (Hump of the table name) and file name (<Name>.go);
//   - the package name from the context configuration;
//   - the parsed fields;
//   - the de-duplicated import paths those fields need.
//
// Errors are logged through FuncLog and returned.
func (mysql *mysqlDbParse) ToStructInfo(ctx context.Context, fullyTableName, tableComment string) (*structInfo, error) {
	dot := strings.Index(fullyTableName, ".")
	if dot <= 0 || dot == len(fullyTableName)-1 {
		err := errors.New("please entry full table name; tip: [database.table]")
		FuncLog(err)
		return nil, err
	}
	dbName, tableName := fullyTableName[:dot], fullyTableName[dot+1:]

	if tableComment == "" {
		err := NewSqlInfo(ctx, "SELECT TABLE_COMMENT FROM information_schema.tables WHERE table_schema = ? AND table_name = ? ", dbName, tableName).ToVar(&tableComment)
		if err != nil {
			FuncLog(err)
			return nil, err
		}
	}

	// SHOW FULL COLUMNS cannot bind identifiers as parameters, so quote each
	// part with backticks (doubling embedded backticks). Otherwise names
	// containing '-', spaces or similar characters break the statement.
	quoteIdent := func(name string) string {
		return "`" + strings.ReplaceAll(name, "`", "``") + "`"
	}
	fields, err := mysql.parseFields(ctx, quoteIdent(dbName)+"."+quoteIdent(tableName))
	if err != nil {
		FuncLog(err)
		return nil, err
	}

	si := &structInfo{}
	si.Comment = tableComment
	si.DbName = dbName
	si.TableName = tableName
	si.Name = Hump(tableName)
	si.FileName = si.Name + ".go"
	si.Fields = fields
	si.ImportPackages = make([]string, 0)
	si.PackageName = getContextConf(ctx).PackageName
	for _, field := range fields {
		if field.ImportPackage == "" || StringInSlice(field.ImportPackage, si.ImportPackages) {
			continue
		}
		si.ImportPackages = append(si.ImportPackages, field.ImportPackage)
	}
	return si, nil
}

// ToAllStructInfo builds struct metadata for every table of the schema
// reported by SELECT DATABASE() (see currentDbName). It delegates to
// ToAllStructInfoOtherDb.
func (mysql *mysqlDbParse) ToAllStructInfo(ctx context.Context) ([]*structInfo, error) {
	current, err := mysql.currentDbName(ctx)
	if err == nil {
		return mysql.ToAllStructInfoOtherDb(ctx, current)
	}
	return nil, err
}

// ToAllStructInfoOtherDb builds struct metadata for every row of
// information_schema.tables whose schema is dbName. Each table's stored
// comment is passed on to ToStructInfo. Processing stops at the first table
// that fails, and that error is returned.
func (mysql *mysqlDbParse) ToAllStructInfoOtherDb(ctx context.Context, dbName string) ([]*structInfo, error) {
	const listQuery = "SELECT TABLE_NAME,TABLE_COMMENT FROM information_schema.tables WHERE table_schema = ?"
	tables, err := NewSqlInfo(ctx, listQuery, dbName).ToRawTable(nil)
	if err != nil {
		return nil, err
	}
	total := tables.RowCount()
	result := make([]*structInfo, 0, total)
	for idx := 0; idx < total; idx++ {
		rec := tables.GetRow(idx)
		info, buildErr := mysql.ToStructInfo(ctx, dbName+"."+rec["TABLE_NAME"], rec["TABLE_COMMENT"])
		if buildErr != nil {
			return nil, buildErr
		}
		result = append(result, info)
	}
	return result, nil
}

// columnTypeParse maps the MySQL column type in fi.ColumnType onto a Go type.
// The input is the "Type" column of SHOW FULL COLUMNS, e.g.
// "int(10) unsigned", "decimal(10,2)" or "enum('A','B')". The function fills
// in fi.Type, fi.Length, fi.Accuracy, fi.ImportPackage, fi.Member, fi.IsEnum
// and fi.IsSet.
//
// The lookup key is the lower-cased base type name. For integer types, the
// key gets a " unsigned" suffix when that modifier is present, whether or not
// a display width precedes it: "int(10) unsigned" and "int unsigned" both
// become "int unsigned". ZEROFILL and other modifiers are ignored.
//
// A converter registered in dataTypeMap under the key takes precedence over
// the built-in table below. Types without a mapping become "any".
//
// fi.Length is -1 unless the type declares or implies a length. An error is
// returned (and logged via FuncLog) for an empty or malformed type string or
// a non-numeric length/scale.
func (*mysqlDbParse) columnTypeParse(fi *fieldInfo) (err error) {
	raw := strings.TrimSpace(fi.ColumnType)
	if len(raw) <= 0 {
		err = errors.New("len(colType) <= 0")
		FuncLog(err)
		return err
	}
	fi.Length = -1

	// Split "name(args) modifiers" on the original text so that args keeps
	// its case (enum/set members are case-preserving). Using the last ')'
	// tolerates parentheses inside quoted enum/set values.
	head, args, tail := raw, "", ""
	if open := strings.Index(raw, "("); open >= 0 {
		closing := strings.LastIndex(raw, ")")
		if closing < open {
			err = errors.New("malformed column type: " + fi.ColumnType)
			FuncLog(err)
			return err
		}
		head, args, tail = raw[:open], raw[open+1:closing], raw[closing+1:]
	}
	words := strings.Fields(strings.ToLower(head + " " + tail))
	if len(words) == 0 {
		err = errors.New("malformed column type: " + fi.ColumnType)
		FuncLog(err)
		return err
	}
	baseType := words[0]
	unsigned := false
	for _, w := range words[1:] {
		if w == "unsigned" {
			unsigned = true
		}
	}
	key := baseType
	switch baseType {
	case "tinyint", "smallint", "mediumint", "int", "bigint":
		if unsigned {
			key += " unsigned"
		}
	}
	fi.Type = key

	if fun := dataTypeMap[key]; fun != nil {
		if err = fun(fi); err != nil {
			FuncLog(err)
			return err
		}
		return nil
	}

	// parseSize reads "M" or "M,D" from args into Length and Accuracy. It
	// leaves both untouched when the type carries no parameters, e.g. a
	// plain "double".
	parseSize := func() error {
		if args == "" {
			return nil
		}
		size := args
		comma := strings.Index(args, ",")
		if comma >= 0 {
			size = args[:comma]
		}
		length, convErr := strconv.Atoi(strings.TrimSpace(size))
		if convErr != nil {
			return convErr
		}
		fi.Length = length
		if comma >= 0 {
			scale, convErr := strconv.Atoi(strings.TrimSpace(args[comma+1:]))
			if convErr != nil {
				return convErr
			}
			fi.Accuracy = scale
		}
		return nil
	}

	// splitMembers splits an enum/set value list on commas outside single
	// quotes. Members keep their surrounding quotes, as before. A doubled
	// '' escape toggles the quote state twice and is therefore handled.
	splitMembers := func(list string) []string {
		members := make([]string, 0)
		inQuote := false
		start := 0
		for i := 0; i < len(list); i++ {
			switch list[i] {
			case '\'':
				inQuote = !inQuote
			case ',':
				if !inQuote {
					members = append(members, list[start:i])
					start = i + 1
				}
			}
		}
		return append(members, list[start:])
	}

	switch key {
	case "tinyint":
		fi.Type = "int8"
	case "tinyint unsigned":
		fi.Type = "uint8"
	case "smallint", "mediumint": // report data-validation errors as early as possible
		// NOTE(review): MySQL MEDIUMINT spans -8388608..8388607, wider than
		// int16. The narrow type looks deliberate per the comment above — confirm.
		fi.Type = "int16"
	case "smallint unsigned", "mediumint unsigned":
		fi.Type = "uint16"
	case "int":
		fi.Type = "int32"
	case "int unsigned":
		fi.Type = "uint32"
	case "bigint":
		fi.Type = "int64"
	case "bigint unsigned":
		fi.Type = "uint64"
	case "float":
		fi.Type = "float32"
		err = parseSize()
	case "double":
		fi.Type = "float64"
		err = parseSize()
	case "decimal":
		fi.Type = "big.Float"
		fi.ImportPackage = "math/big"
		err = parseSize()
	case "char", "varchar":
		fi.Type = "string"
		err = parseSize()
	case "tinytext":
		fi.Type = "string"
		fi.Length = 255
	case "text":
		fi.Type = "string"
		fi.Length = 65535
	case "mediumtext":
		fi.Type = "string"
		fi.Length = 16777215
	case "longtext":
		fi.Type = "string"
		fi.Length = 4294967295
	case "enum":
		fi.Type = "string"
		fi.Member = splitMembers(args)
		fi.IsEnum = true
	case "set":
		fi.Type = "string"
		fi.Member = splitMembers(args)
		fi.IsSet = true
	case "date", "time", "datetime", "timestamp":
		fi.Type = "time.Time"
		fi.ImportPackage = "time"
	case "year": // YEAR holds 1901..2155 (or 0000), which does not fit in uint8
		fi.Type = "uint16"
		fi.Length = 4
	case "json":
		fi.Type = "string"
	case "binary", "varbinary":
		fi.Type = "[]byte"
		err = parseSize()
	case "tinyblob":
		fi.Type = "[]byte"
		fi.Length = 255
	case "blob":
		fi.Type = "[]byte"
		fi.Length = 65535
	case "mediumblob":
		fi.Type = "[]byte"
		fi.Length = 16777215
	case "longblob":
		fi.Type = "[]byte"
		fi.Length = 4294967295
	default: // other type
		fi.Type = "any"
	}
	if err != nil {
		FuncLog(err)
		return err
	}
	return nil
}
